# CLEAN WORKSPACE AND LOAD PACKAGES --------------------------------------------

rm(list = ls())
library(datasim)
library(tidyverse)

# SIMULATE MULTIVARIATE SPATIAL DATA -------------------------------------------

# Fix the RNG state so the simulations below are reproducible.
set.seed(3)

# Covariance handed to the mre() term: a correlation matrix scaled on both
# sides by a diagonal matrix holding the value 0.3 for each of three responses.
Corr <- rbind(
  c(1, -0.3, 0),
  c(-0.3, 1, 0.3),
  c(0, 0.3, 1)
)
sigmas <- rep(0.3, times = 3)
D <- diag(sigmas)
Cov <- D %*% Corr %*% D

# Coefficients handed to the mfe() term on x1 (all zero in this run).
# beta <- c(-0.5, 0, 0.5)
beta <- numeric(3)

# Settings handed to the mgp() term: a diagonal 3 x 3 matrix with 0.3 on the
# diagonal, the correlation model name, and one `phi` value per response.
variance <- 0.3 * diag(3)
cor.model <- "exp_cor"
cor.params <- lapply(c(0.04, 0.04, 0.1), function(phi) list(phi = phi))

# Model specification handed to sim_model() below. The `mean` formula adds
# three terms: mfe() on x1 with coefficients `beta`, mre() on factor(id) with
# covariance `Cov`, and mgp() on s1 with the `variance`, `cor.model` and
# `cor.params` objects defined above. The `sd` formula is the constant 0.
# NOTE(review): the objects are fetched with get("...") instead of being
# referenced directly -- presumably because datasim evaluates the terms in its
# own environment; confirm before simplifying.
f <- list(
  mean ~ mfe(x1, beta = get("beta")) +
    mre(factor(id), sigma = get("Cov")) +
    mgp(list(s1), variance = get("variance"), cor.model = get("cor.model"),
        cor.params = get("cor.params")),
  sd ~ I(0)
  )

# `n` and `m` are passed as sim_model()'s `n` and `responses` arguments; the
# printed result below has n * m = 900 rows. The outer parentheses print the
# tibble right after assigning it.
n <- 300
m <- 3
(data_geo <- sim_model(formula = f, n = n, responses = m))
## # A tibble: 900 x 9
##       id      x1    s1 mre.factor.mean mgp.list.mean    mean    sd
##    <int>   <dbl> <dbl>           <dbl>         <dbl>   <dbl> <dbl>
##  1     1 -1.29   0.168         0.00419       -0.0457 -0.0415     0
##  2     2  2.64   0.808        -0.490          0.336  -0.154      0
##  3     3  0.487  0.385         0.147         -0.549  -0.402      0
##  4     4  0.854  0.328        -0.306         -0.0600 -0.366      0
##  5     5  1.09   0.602         0.327         -0.228   0.0994     0
##  6     6  0.226  0.604        -0.342         -0.291  -0.633      0
##  7     7  0.0682 0.125        -0.00472        0.433   0.429      0
##  8     8 -0.985  0.295         0.0892        -0.845  -0.756      0
##  9     9 -1.31   0.578         0.960         -0.0223  0.938      0
## 10    10  2.46   0.631         0.0268         0.817   0.843      0
## # ... with 890 more rows, and 2 more variables: response <dbl>,
## #   response_label <int>
# knitr::kable(head(data_geo, 10))

# VISUALIZE MULTIVARIATE SPATIAL DATA ------------------------------------------

# Response against the covariate x1, coloured by response label, with a
# smoother drawn underneath the points.
ggplot(data_geo, aes(x = x1, y = response, colour = factor(response_label))) +
  geom_smooth() +
  geom_point()
## `geom_smooth()` using method = 'loess' and formula 'y ~ x'

# The mgp() component along the coordinate s1, one line per response label.
ggplot(
  data_geo,
  aes(x = s1, y = mgp.list.mean, colour = factor(response_label))
) +
  geom_line()

# Pairwise plot of the mre() component: one column per response label, with
# the id column dropped before plotting.
data_geo %>%
  dplyr::select(id, mre.factor.mean, response_label) %>%
  spread(key = response_label, value = mre.factor.mean) %>%
  dplyr::select(-id) %>%
  GGally::ggpairs(mapping = aes(fill = "any"))

# Wide version of the simulated data. Every column from `mre.factor.mean`
# through `ability` (the renamed `response`) is gathered, its name is suffixed
# with the response label (giving e.g. ability1, ability2, ...), and the values
# are spread back out, leaving one row per combination of the remaining
# columns (id_person, x1, s1).
# `select` is namespace-qualified like every other call in this script, so the
# step keeps working if another attached package masks `select`.
data_geo_wide <- data_geo %>%
  dplyr::rename(ability = response, id_person = id) %>%
  gather(var, value, mre.factor.mean:ability) %>%
  mutate(var = paste0(var, response_label)) %>%
  dplyr::select(-response_label) %>%
  spread(var, value)


# SIMULATE ITEM FACTOR DATA ----------------------------------------------------

# Number of items.
q <- 10

# Starting data for the item simulation: `q` stacked copies of the wide data.
# All copies are bound in one call instead of accumulating them pairwise with
# reduce(rbind), which re-copies the growing result at every step.
init_data <- dplyr::bind_rows(rep(list(data_geo_wide), q))

# n <- 300
# Coefficients for the mfa() term of the item model, stored as a 1 x q matrix.
difficulty <- matrix((seq_len(q) - 5) / 10 * 2, nrow = 1)

# One discrimination vector per ability. The first is an even grid, the other
# two are uniform draws on (0, 2); the leading entries are then overwritten
# with fixed values (1 / 0, 1 / 0, 0, 1).
# NOTE(review): the fixed entries look like identification constraints --
# confirm against the sampler.
# The two runif() calls stay in this order so the random stream is consumed
# exactly as before.
discrimination1 <- replace(seq(0.4, 1.5, length.out = q), 1, 1)
discrimination2 <- replace(runif(q, 0, 2), 1:2, c(0, 1))
discrimination3 <- replace(runif(q, 0, 2), 1:3, c(0, 0, 1))
# discrimination1 <- discrimination1 * 0.3
# discrimination2 <- discrimination2 * 0.3
cbind(discrimination1, discrimination2, discrimination3)
##       discrimination1 discrimination2 discrimination3
##  [1,]       1.0000000       0.0000000      0.00000000
##  [2,]       0.5222222       1.0000000      0.00000000
##  [3,]       0.6444444       1.3146662      1.00000000
##  [4,]       0.7666667       0.5407459      1.31471017
##  [5,]       0.8888889       1.7505244      0.36458530
##  [6,]       1.0111111       0.2192177      0.15044220
##  [7,]       1.1333333       1.0466811      0.03345367
##  [8,]       1.2555556       1.3502482      0.12519519
##  [9,]       1.3777778       0.4543743      0.25297449
## [10,]       1.5000000       1.3584900      0.37496552
# Item-response model handed to sim_model() below (this overwrites the earlier
# `f`). `prob` is built from an mfa() term on `ones` with coefficients
# `difficulty` plus one mfe() term per ability column (ability1, ability2),
# each using the matching discrimination vector; the ability3 term is
# commented out. `size` is the constant 1.
f <- list(
  prob ~ mfa(ones, beta = get("difficulty")) +
    mfe(ability1, beta = get("discrimination1")) +
    mfe(ability2, beta = get("discrimination2")),
  # + mfe(ability3, beta = get("discrimination3")),
  size ~ I(1)
  )

# Simulate `q` responses for `n` units on top of `init_data`, with rbinom() as
# the generator.
# NOTE(review): link_inv presumably supplies the inverse links for `prob` and
# `size` in that order (pnorm, identity) -- confirm in the datasim docs.
data_long <- sim_model(formula = f,
                        link_inv = list(pnorm, identity),
                        generator = rbinom,
                        responses = q,
                        n = n,
                        init_data = init_data
                        )

# Rename the simulated columns to the names used in the rest of the script.
data_long <- dplyr::rename(data_long, subject = id,
                           item = response_label, y = response)

# VISUALIZE ITEM FACTOR DATA ---------------------------------------------------

# Per-subject summary: the mean of the responses `y` together with the
# subject's (unique) abilities and covariate x1.
explor <- summarize(
  group_by(data_long, subject),
  endorse = mean(y),
  ability1 = unique(ability1),
  ability2 = unique(ability2),
  # ability3 = unique(ability3),
  x1 = unique(x1)
)

# Mean response against each simulated ability.
explor %>%
  ggplot(aes(x = ability1, y = endorse)) +
  geom_point(alpha = 0.5)

explor %>%
  ggplot(aes(x = ability2, y = endorse)) +
  geom_point(alpha = 0.5)

# ggplot(explor, aes(ability3, endorse)) + geom_point(alpha = 0.5)
# ggplot(explor, aes(x1, endorse)) + geom_point(alpha = 0.5)

# Compile the C++ source and load the R plotting helpers. Both paths are
# relative, so they resolve against the current working directory.
# NOTE(review): ifa_gibbs() is presumably defined by the C++ file, and
# as_tibble.spmirt.list() / gg_*() by ggplot-mcmc.R -- confirm.
Rcpp::sourceCpp("../src/mirt-gibss.cpp")
source("../R/ggplot-mcmc.R")

# Run the sampler for `iter` iterations on the simulated responses and time it.
# NOTE(review): the last argument (2) is presumably the number of latent
# dimensions, matching the two ability terms of the simulation -- confirm.
iter <- 10000
system.time(samples <- ifa_gibbs(data_long$y, n, q, iter, 2))
##    user  system elapsed 
## 337.422 119.141 114.447

# Convert the draws to a tibble and print its summary (quantiles per
# parameter, see output below).
# NOTE(review): the second argument (iter/2) is presumably the number of
# initial draws to discard -- confirm in ggplot-mcmc.R.
samples_tib <- as_tibble.spmirt.list(samples, iter/2)
summary(samples_tib)
## # A tibble: 3,630 x 6
##    Parameters `2.5%`   `10%`   `50%`  `90%` `97.5%`
##    <fct>       <dbl>   <dbl>   <dbl>  <dbl>   <dbl>
##  1 V1         -1.44  -0.953  -0.0924  0.783   1.29 
##  2 V2         -0.844 -0.384   0.539   1.47    1.99 
##  3 V3         -2.24  -1.68   -0.765   0.125   0.649
##  4 V4         -1.82  -1.34   -0.494   0.375   0.883
##  5 V5         -1.82  -1.33   -0.466   0.404   0.929
##  6 V6         -1.22  -0.805   0.0169  0.917   1.40 
##  7 V7         -0.923 -0.369   0.668   1.63    2.14 
##  8 V8         -2.76  -2.25   -1.28   -0.334   0.202
##  9 V9         -0.322  0.228   1.34    2.33    2.92 
## 10 V10        -0.414  0.0143  0.892   1.84    2.41 
## # ... with 3,620 more rows
# Long (key/value) version of the draws; not referenced again in the active
# part of this script.
samples_long <- gather(samples_tib)

# Trace and density plots of the sampler output.
# NOTE(review): in as_tibble.spmirt.list() the second argument is presumably
# the number of initial draws to drop (0 = keep all, iter / 2 = drop the first
# half), the third a thinning interval, and the fourth the parameter block to
# extract -- confirm in ggplot-mcmc.R.
# Every column subset uses dplyr::select explicitly, as elsewhere in this
# script, so the calls do not depend on which package's `select` is on top of
# the search path after sourcing the helper file.

# Traces of the "c" and "a" blocks over the whole chain.
as_tibble.spmirt.list(samples, 0, 10, "c") %>%
  gg_trace(alpha = 0.6)

as_tibble.spmirt.list(samples, 0, 10, "a") %>%
  gg_trace(alpha = 0.6)

# Ridge densities of the "a" block from the second half of the chain.
as_tibble.spmirt.list(samples, iter / 2, 10, "a") %>%
  gg_density(alpha = 0.5, ridges = TRUE, aes(fill = Parameters), scale = 4)
## Picking joint bandwidth of 0.0718

# Ridge densities of the first 100 "theta" columns.
as_tibble.spmirt.list(samples, iter / 2, 10, "theta") %>%
  dplyr::select(1:100) %>%
  gg_density(alpha = 0.5, ridges = TRUE, aes(fill = Parameters), scale = 4)
## Picking joint bandwidth of 0.184

# Traces of the first 10 "theta" columns over the whole chain.
as_tibble.spmirt.list(samples, 0, 10, "theta") %>%
  dplyr::select(1:10) %>%
  gg_trace(alpha = 0.6)

# Joint view of the two discrimination columns of the "a" block, with the
# simulated discriminations passed as `highlight`.
as_tibble.spmirt.list(samples, 0, 10, "a") %>%
  gg_density2d(
    `Discrimination 1`, `Discrimination 2`,
    each = 10,
    keys = c("Item ", "Discrimination "),
    highlight = c(discrimination1, discrimination2)
  )
## Warning: Computation failed in `stat_density2d()`:
## bandwidths must be strictly positive

as_tibble.spmirt.list(samples, 0, 10, "a") %>%
  gg_scatter(
    `Discrimination 1`, `Discrimination 2`,
    each = 10,
    keys = c("Item ", "Discrimination "),
    highlight = c(discrimination1, discrimination2)
  )

# Interval plot of the "a" block from the second half of the chain, with the
# simulated discriminations overlaid as points.
as_tibble.spmirt.list(samples, iter / 2, select = "a") %>%
  summary() %>%
  mutate(param = c(discrimination1, discrimination2)) %>%
  gg_errorbarh() +
  geom_point(aes(x = param, y = Parameters), colour = 3)

# Same comparison for the "c" block against the simulated `difficulty`.
as_tibble.spmirt.list(samples, iter / 2, select = "c") %>%
  summary() %>%
  mutate(param = as.vector(difficulty)) %>%
  gg_errorbarh() +
  geom_point(aes(x = param, y = Parameters), colour = 3)

# Compare the posterior "theta" summaries with the simulated abilities, one
# latent dimension at a time. The slices are derived from `n` instead of the
# literals 1:300 / 301:600, so this section stays correct when the sample size
# changes, and the code shared by both dimensions lives in small helpers.
# NOTE(review): like the original positional slices, this assumes that both
# the "theta" columns and the rows of `data_geo` are stacked dimension by
# dimension with `n` entries each -- confirm against ifa_gibbs() and
# sim_model().

# Positions of the n entries of latent dimension `k`
# (k = 1 -> 1..n, k = 2 -> (n + 1)..2n).
ability_index <- function(k) {
  as.integer((k - 1) * n + seq_len(n))
}

# Summary of the "theta" columns of dimension `k` (extracted with iter / 2 as
# second argument, as everywhere else in this script) with the simulated
# response attached as `param`.
ability_summary <- function(k) {
  idx <- ability_index(k)
  as_tibble.spmirt.list(samples, iter / 2, select = "theta") %>%
    # `!!` makes select() treat idx as column positions, not a column name.
    dplyr::select(!!idx) %>%
    summary() %>%
    mutate(param = data_geo$response[idx])
}

# ability_summary() plus the coordinate s1, a copy of it as s2, and the
# posterior median (`50%`) as `estim`, ready for gstat::variogram().
# NOTE(review): because s2 duplicates s1, the Euclidean distances used by the
# variogram are sqrt(2) times the distances along s1 -- keep this in mind when
# reading ranges off the plot.
ability_prediction <- function(k) {
  ability_summary(k) %>%
    mutate(s1 = data_geo$s1[ability_index(k)],
           s2 = s1,
           estim = `50%`)
}

# Posterior median and simulated value ("real") as lines along s1.
plot_ability_profile <- function(pred) {
  ggplot(pred, aes(s1, `50%`)) +
    geom_line() +
    geom_line(aes(s1, param, col = "real"))
}

# Empirical variogram: semivariance against distance, point size = number of
# pairs, with a smoother; the x axis is limited to [0, 0.7].
plot_variogram <- function(vg) {
  ggplot(vg, aes(dist, gamma)) +
    geom_point(aes(size = np)) +
    geom_smooth() +
    expand_limits(y = 0, x = 0) +
    scale_x_continuous(limits = c(0, 0.7))
}

# Sorted interval plots with the simulated values overlaid, per dimension.
ability_summary(1) %>%
  gg_errorbarh(sorted = TRUE) +
  geom_point(aes(x = param), col = 3)

ability_summary(2) %>%
  gg_errorbarh(sorted = TRUE) +
  geom_point(aes(x = param), col = 3)

# Dimension 1: profile along s1 and empirical variogram of the estimates.
ability1_pred <- ability_prediction(1)
plot_ability_profile(ability1_pred)

vg <- gstat::variogram(estim ~ 1, ~ s1 + s2, ability1_pred,
                       cutoff = 1, width = 0.01)
plot_variogram(vg)
## `geom_smooth()` using method = 'loess' and formula 'y ~ x'
## Warning: Removed 30 rows containing non-finite values (stat_smooth).
## Warning: Removed 30 rows containing missing values (geom_point).

# Dimension 2: same two plots (`vg` is overwritten, as before).
ability2_pred <- ability_prediction(2)
plot_ability_profile(ability2_pred)

vg <- gstat::variogram(estim ~ 1, ~ s1 + s2, ability2_pred,
                       cutoff = 1, width = 0.01)
plot_variogram(vg)
## `geom_smooth()` using method = 'loess' and formula 'y ~ x'
## Warning: Removed 30 rows containing non-finite values (stat_smooth).

## Warning: Removed 30 rows containing missing values (geom_point).

# # PREPARE DATA FOR MODELLING ---------------------------------------------------
#
# Y <- data_model %>% dplyr::select(id, response, response_label) %>%
#   spread(response_label, response) %>%
#   arrange(id) %>%
#   dplyr::select(-id) %>%
#   as.matrix()
#
# X <- data_model %>% dplyr::select(id, matches("^x[[:digit:]]+$")) %>%
#   unique() %>%
#   arrange(id) %>%
#   dplyr::select(-id) %>%
#   as.matrix()
#
# Beta <- matrix(beta, nrow = 1)
# Sigma_proposal <- diag(1, 3)
#
# # RUN MODEL --------------------------------------------------------------------
#
# getwd()
# Rcpp::sourceCpp("../src/multi-lm.cpp")
# source("../R/ggplot-mcmc.R")
#
# iter <- 10^6
# system.time(
#   samples <- multi_lm(Y, X, iter, 0.01 * Sigma_proposal, 0.001 * Sigma_proposal)
# )
# samples %>% map(~ tail(.))
#
# # apply(samples$beta, 2, mean)
# # cor(samples$beta)
#
# # Visualize traces
# as_tibble(samples, 0, 100, select = "beta") %>%
#   gg_trace(wrap = TRUE, alpha = 0.6)
#
# as_tibble(samples, 0, 100, select = "beta") %>% gg_trace(alpha = 0.6)
# as_tibble(samples, 0, 100, select = "corr_chol") %>% gg_trace(alpha = 0.6)
# as_tibble(samples, 0, 100, select = "corr") %>% gg_trace(alpha = 0.6)
# as_tibble(samples, 0, 100, select = "sigmas") %>% gg_trace(alpha = 0.6)
#
# bla <- as_tibble(samples, iter/2, select = "sigmas")
# cov(log(bla))
# nrow(unique(bla)) / nrow(bla)
#
# bla <- as_tibble(samples, iter/2, select = "corr_chol")
# cov(bla)
# nrow(unique(bla)) / nrow(bla)
#
# # Visualize densities
#
# as_tibble(samples, iter / 2, select = "corr_chol") %>%
#   gg_density(aes(fill = Parameters), scale = 2, alpha = 0.5, ridges = TRUE)
#
# as_tibble(samples, iter / 2, select = "corr") %>%
#   gg_density(aes(fill = Parameters), scale = 1, alpha = 0.5, ridges = TRUE)
#
# # Visualize credible intervals
# as_tibble(samples, iter / 2, select = "beta") %>%
#   summary() %>%
#   mutate(param = beta) %>%
#   gg_errorbarh() +
#   geom_point(aes(param, Parameters), col = 3)
#
# Corr_chol <- t(chol(Corr))
# corr_chol <- Corr_chol[lower.tri(Corr_chol, diag = TRUE)]
# corr <- Corr[lower.tri(Corr)]
#
# as_tibble(samples, iter / 2, select = "corr_chol") %>%
#   summary() %>%
#   mutate(param = corr_chol) %>%
#   gg_errorbarh() +
#   geom_point(aes(param, Parameters), col = 3)
#
# as_tibble(samples, iter / 2, select = "corr") %>%
#   summary() %>%
#   mutate(param = corr) %>%
#   gg_errorbarh() +
#   geom_point(aes(param, Parameters), col = 3)
#
#
# as_tibble(samples, iter / 2 ,select = "sigmas") %>%
#   summary() %>%
#   mutate(param = sigmas) %>%
#   gg_errorbarh() +
#   geom_point(aes(param, Parameters), col = 3)
#
#
# # Visualize credible intervals for all Parameters
# as_tibble(samples, iter / 2) %>%
#   summary() %>%
#   mutate(param = c(beta, corr_chol, corr, sigmas)) %>%
#   gg_errorbar() +
#   geom_point(aes(Parameters, param), col = 3)
#